/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef DISC_TRANSFORMS_TRANSFORM_SCHEDULE_H_
#define DISC_TRANSFORMS_TRANSFORM_SCHEDULE_H_

#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>

#include "mlir/disc/transforms/fusion_utils.h"

namespace mlir {
namespace disc_ral {

// PatternKind represents the category of a given schedule.
// Schedules sharing one `PatternKind` may still use different strategies for
// different shape ranges; tags are used to tell such schedules apart within
// the same category.
enum class PatternKind : int32_t {
  kNone = 0,
  kGEMM = 1,
};

// Target device category.
// NOTE(review): `kNone` presumably denotes an unspecified device -- confirm
// against the users of this enum.
enum class DeviceType {
  kCPU = 0,
  kGPU = 1,
  kNone = 2,
};

// Plain aggregate of device properties; every field defaults to -1.
// NOTE(review): the `cc_*` / `sm_*` names suggest GPU compute capability and
// streaming-multiprocessor figures -- confirm against the code that fills in
// this struct.
struct DeviceInfo {
  int cc_major{-1};
  int cc_minor{-1};
  int sm_count{-1};
  int max_threads_per_sm{-1};
};

// Converts a pattern kind to its string representation.
// Declaration only; the definition is not part of this header.
std::string patternKindToString(PatternKind kind);

// Creates a pattern kind from its string representation.
// NOTE(review): the result for a string that names no known kind is not
// visible from this header -- check the definition before relying on a
// particular fallback value.
PatternKind patternKindFromString(const std::string& str);

// PatternDescription collects everything needed to assign a schedule for a
// given fusion pattern.
//
// The fusion pattern and the shape analysis are stored as references (see the
// private members), so the objects passed to the constructor are expected to
// outlive this descriptor.
class PatternDescription {
 public:
  explicit PatternDescription(lmhlo::FusionOp op, FusionPattern& fusionPattern,
                              ShapeAnalysis& shapeAnalysis);

  // Returns the kind of this `PatternDescription`.
  PatternKind getPatternKind() const;

  // Returns the tags attached to this fusion pattern as a single string.
  std::string getPatternTagStr() const;

  // Returns the full pattern kind str + tag str.
  std::string getTaggedPatternStr() const;

  // Returns the tags attached to this fusion pattern as a set.
  const std::set<std::string>& getPatternTagSet() const;

  // Returns the device type of this fusion pattern.
  DeviceType getPatternDeviceType() const;

  // Returns the fusion op this descriptor holds.
  lmhlo::FusionOp getFusionOp() const { return op_; }

  // Returns the fusion pattern this descriptor holds. The descriptor only
  // refers to the pattern, so a mutable reference is handed out even through
  // a const descriptor.
  FusionPattern& getFusionPattern() const { return fusionPattern_; }

  // Returns the shape analysis this descriptor holds. As with the fusion
  // pattern, the analysis is referenced rather than owned.
  ShapeAnalysis& getShapeAnalysis() const { return shapeAnalysis_; }

 private:
  lmhlo::FusionOp op_;
  FusionPattern& fusionPattern_;
  ShapeAnalysis& shapeAnalysis_;
  PatternKind patternKind_;
  std::set<std::string> tagSet_;
  DeviceType deviceType_;
};

// Tag naming the default schedule factory of a pattern kind.
// The default schedule should always have the lowest priority and generate
// workable code for any candidate pattern.
extern const char* kDefaultScheduleFactoryTag;

// Priority of the default schedule factory.
constexpr int kDefaultScheduleFactoryPriority = 0;

// NOTE(review): judging by the name, the first priority given to factories
// parsed from schedule files -- confirm against the dispatcher
// implementation.
constexpr int kParsedFromFileScheduleFactoryStartPriority = 10000;

// Factory used to assign a specific schedule for the given PatternDescription.
//
// A factory is identified by an id and targets one pattern kind, a set of tags
// and a device type. Subclasses customize behavior through the virtual hooks
// below.
class ScheduleFactory {
 public:
  explicit ScheduleFactory(int64_t id, PatternKind kind,
                           ArrayRef<StringRef> tags, DeviceType deviceType);
  virtual ~ScheduleFactory() = default;

  // Returns true if the factory accepts the pattern at compile time.
  // Note that it's only a conservative check, which means a pattern rejected
  // by this check is definitely not supported by this factory. We need to
  // inject runtime check logic to further protect the schedule due to missing
  // some information at compile time.
  virtual bool accept(PatternDescription&);

  // Returns true if the schedule can always generate workable code for the
  // given pattern once it passes the compile-time check.
  virtual bool noGuardCondition(PatternDescription&);

  // Builds the runtime guard IR to protect the schedule generated by this
  // factory.
  // NOTE(review): the trailing `Value&` presumably receives the generated
  // guard condition -- confirm against the definition.
  virtual LogicalResult buildGuardCondition(OpBuilder& b, Location loc,
                                            PatternDescription&, Value&);

  // Assigns the transform schedule and attaches it to the module op.
  // The pattern should be accepted by this factory and the guard condition
  // should have been emitted successfully beforehand.
  virtual LogicalResult assignSchedule(PatternDescription&, ModuleOp,
                                       DeviceInfo);

  // Returns the id this factory has.
  int64_t getId() const { return id_; }

  // Returns the pattern kind this factory has.
  PatternKind getPatternKind() const { return kind_; }

  // Returns the tag set this factory has.
  const std::set<std::string>& getTagSet() const { return tagSet_; }

  // Returns the device type this factory corresponds to.
  DeviceType getDeviceType() const { return deviceType_; }

 protected:
  // These are called by `accept`. No need to check device type as the kind and
  // tags already determine a unique target.
  virtual bool checkKindAndTags(PatternDescription&);
  virtual bool checkFusionPatternProperties(PatternDescription&);

 protected:
  int64_t id_;
  PatternKind kind_;
  std::set<std::string> tagSet_;
  DeviceType deviceType_;
};

// A `ScheduleFactory` whose `noGuardCondition` unconditionally reports true.
class ScheduleFactoryWithNoGuard : public ScheduleFactory {
 public:
  // Reuse the base class constructors unchanged.
  using ScheduleFactory::ScheduleFactory;

  bool noGuardCondition(PatternDescription&) override { return true; }
};

// Owning handle to a `ScheduleFactory`.
using ScheduleFactoryPtr = std::unique_ptr<ScheduleFactory>;

// A registry for different schedule factories.
//
// Accessed through the singleton returned by `get()`. Factories are stored per
// pattern kind in a map keyed by priority, and the registry owns them through
// `ScheduleFactoryPtr`.
class ScheduleFactoryRegistry {
 public:
  // Returns the singleton.
  static ScheduleFactoryRegistry& get();

  // Returns next available id for schedule factory.
  static int64_t getNextUniqueId();

  // The registry is a singleton: forbid copying and assignment. Declaring
  // these also suppresses the implicit move operations, so the instance
  // returned by `get()` cannot be emptied by an accidental `std::move`.
  ScheduleFactoryRegistry(const ScheduleFactoryRegistry&) = delete;
  ScheduleFactoryRegistry& operator=(const ScheduleFactoryRegistry&) = delete;

  // Inserts the new `ScheduleFactory`. Returns true if inserted, otherwise
  // false. The larger the `priority`, the larger the chance of the factory
  // being chosen. Note that we do not allow assigning the same priority to two
  // different factories with the same pattern kind.
  bool registerScheduleFactory(PatternKind kind, int priority,
                               ScheduleFactoryPtr factory);

  // Removes the factory registered for `kind` with the given `priority`.
  // NOTE(review): behavior when no such factory exists is not visible from
  // this header -- check the definition.
  void unregisterScheduleFactory(PatternKind kind, int priority);

  // Returns the schedule factory with the highest priority for `pd`.
  // Returns nullptr if not found.
  ScheduleFactory* getScheduleFactoryWithHighestPriority(
      PatternDescription& pd);

  // Returns all suitable schedule factories for `pd`. The returned factory list
  // is sorted by `priority`. The first one has the highest priority.
  SmallVector<ScheduleFactory*> getAllCandidateScheduleFactories(
      PatternDescription& pd);

 private:
  // Only `get()` is meant to create the instance.
  ScheduleFactoryRegistry() = default;

  // <pattern-kind, <priority, factory>>
  std::unordered_map<PatternKind, std::map<int, ScheduleFactoryPtr>>
      patternMap_;
};

// Macros used to define disc transform schedule factory.
//
// `DISC_TRANSFORM_SCHEDULE(kind, priority, T, args...)` is the entry point. It
// registers a newly created `T` with the `ScheduleFactoryRegistry` under
// `kind` and `priority`. `T` is constructed as `T(<next unique id>, kind,
// args...)`; because `__VA_ARGS__` follows a comma in the expansion, at least
// one extra argument has to be supplied.
#define DISC_TRANSFORM_SCHEDULE(kind, priority, T, ...)               \
  DISC_TRANSFORM_SCHEDULE_UNIQ_HELPER(__COUNTER__, kind, priority, T, \
                                      __VA_ARGS__)

// Extra level of indirection so that `__COUNTER__` is expanded to a number
// before it reaches the `##` token pasting in `DISC_TRANSFORM_SCHEDULE_UNIQ`.
#define DISC_TRANSFORM_SCHEDULE_UNIQ_HELPER(ctr, kind, priority, T, ...) \
  DISC_TRANSFORM_SCHEDULE_UNIQ(ctr, kind, priority, T, __VA_ARGS__)

// Defines a static bool named `unused_ret_val_<ctr>` that is initialized by an
// immediately-invoked lambda. The lambda performs the registration, prints a
// message to `llvm::dbgs()` when registration fails, and yields the
// registration result. The expansion already ends with `;`.
#define DISC_TRANSFORM_SCHEDULE_UNIQ(ctr, kind, priority, T, ...)             \
  static bool unused_ret_val_##ctr = []() {                                   \
    bool ret = ::mlir::disc_ral::ScheduleFactoryRegistry::get()               \
                   .registerScheduleFactory(                                  \
                       kind, priority,                                        \
                       std::make_unique<T>(                                   \
                           ::mlir::disc_ral::ScheduleFactoryRegistry::        \
                               getNextUniqueId(),                             \
                           kind, __VA_ARGS__));                               \
    if (!ret) ::llvm::dbgs() << "failed to register a new scheduleFactory\n"; \
    return ret;                                                               \
  }();

// Assigns a schedule for the given PatternDescription according to its kind
// and tag.
class ScheduleDispatcher {
 public:
  // Users may override the schedule by providing their own implementation and
  // passing the schedule files to the dispatcher.
  // Format of `transformFileName`:
  //  "<kind-0>:<tag-str-0>:<filename-0>;<kind-1>:<tag-str-1>:<filename-1>;"
  explicit ScheduleDispatcher(const std::string& transformFileName);
  ~ScheduleDispatcher();

  // Attaches a schedule for the given pattern description.
  LogicalResult dispatch(PatternDescription& pd, ModuleOp m);

  // Parses schedule modules from the given files.
  LogicalResult parseModuleFromFile(MLIRContext* ctx);

  // Stores a copy of `deviceInfo` in this dispatcher.
  void setDeviceInfo(const DeviceInfo& deviceInfo) { deviceInfo_ = deviceInfo; }

  // Returns the stored device info; a default-constructed `DeviceInfo` until
  // `setDeviceInfo` is called.
  const DeviceInfo& getDeviceInfo() const { return deviceInfo_; }

 private:
  // Schedule file specification as passed to the constructor.
  // NOTE(review): presumably holds the `transformFileName` argument -- the
  // constructor is defined elsewhere.
  std::string transformFileName_;
  // <pattern-kind, <tag-str, module-op>>
  std::unordered_map<PatternKind,
                     std::unordered_map<std::string, OwningOpRef<ModuleOp>>>
      parsedModuleMap_;
  DeviceInfo deviceInfo_;
};

}  // namespace disc_ral
}  // namespace mlir

#endif  // DISC_TRANSFORMS_TRANSFORM_SCHEDULE_H_
